# Copyright 2020 ByteDance Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy

import numpy
import tensorflow as tf

from neurst.models import build_model
from neurst.utils.hparams_sets import get_hyper_parameters


def test_seq2seq():
    params = copy.deepcopy(get_hyper_parameters("transformer_toy")["model.params"])
    params["modality.source.dim"] = None
    params["modality.target.dim"] = None
    params["modality.source.timing"] = None
    params["modality.target.timing"] = None
    src_vocab_meta = dict(vocab_size=8, eos_id=7, bos_id=6, unk_id=5)
    trg_vocab_meta = dict(vocab_size=5, eos_id=4, bos_id=3, unk_id=2)
    parsed_inputs = {
        "src": tf.convert_to_tensor(
            [[0, 1, 1, 7], [1, 7, 7, 7]], tf.int64),
        "src_padding": tf.convert_to_tensor([[0, 0, 0, 0.], [0, 0, 1, 1.]], tf.float32),
        "trg_input": tf.convert_to_tensor([[3, 0, 1], [3, 2, 4]], tf.int32),
        "trg": tf.convert_to_tensor([[0, 1, 4], [2, 4, 4]], tf.int32),
        "trg_padding": tf.convert_to_tensor([[0, 0, 0.], [0, 0, 1.]], tf.float32),
    }

    model = build_model({"model.class": "transformer", "params": params},
                        src_meta=src_vocab_meta, trg_meta=trg_vocab_meta)
    _ = model(parsed_inputs, is_training=False)
    for w in model.trainable_weights:
        if "target_symbol_modality/shared/weights" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.29354253, -0.23483634, 0.25630027, -0.02696097, -0.5017841,
                  -0.01427859, 0.64076746, 0.10676116],
                 [-0.19711176, -0.20760003, -0.48422408, -0.0074994, -0.31429327,
                  0.00126553, -0.17251879, 0.29386985],
                 [0.38033593, -0.27076742, 0.2611575, 0.66763735, 0.5333196,
                  -0.52800345, -0.5451049, 0.5960151],
                 [-0.38007882, 0.47841036, 0.11322564, 0.3999585, -0.5566431,
                  -0.6169907, 0.5290351, -0.48975855],
                 [0.24198133, -0.1712935, -0.13487989, 0.03922045, -0.27576318,
                  0.15308863, 0.18018633, -0.49891895]]
            ))
        elif "target_symbol_modality/shared/bias" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [-0.12844944, 0.70201373, 0.47467923, 0.17776501, -0.57099354]
            ))

        elif "input_symbol_modality/emb/weights" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.28932106, 0.04174006, 0.32917994, -0.01771283, -0.32744384,
                  0.4569562, -0.4678616, 0.00129563],
                 [-0.4225411, -0.59086347, -0.0714885, 0.51049083, -0.5401395,
                  0.3862279, -0.53301275, 0.30440414],
                 [-0.19314134, 0.09168714, -0.5058322, -0.42353332, 0.5074443,
                  0.03560042, 0.26724458, 0.33088684],
                 [-0.5153856, -0.38528442, -0.20011288, 0.4713922, 0.13764167,
                  -0.18305543, -0.43612635, 0.5469119],
                 [-0.54713076, 0.32743508, 0.38312858, -0.5525645, 0.591134,
                  0.1707223, 0.15555906, -0.42832434],
                 [-0.5138424, -0.21375301, -0.46360433, -0.6103692, -0.50063866,
                  0.24583805, -0.5414497, -0.01820809],
                 [0.3424672, -0.38758308, 0.05292654, 0.10646945, -0.09475929,
                  0.5051289, 0.16801137, 0.03101033],
                 [-0.10960919, 0.20824891, -0.02183038, -0.06829894, 0.48780817,
                  -0.18522224, 0.22240955, -0.21551234]]
            ))
        elif ("TransformerEncoder/layer_0/self_attention_prepost_wrapper/"
              "self_attention/output_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.31903958, 0.41097552, 0.35810417, 0.4822548, 0.5416022,
                  0.02170408, 0.32241964, -0.54333895],
                 [0.5172518, 0.14113712, 0.44610864, -0.43546906, 0.49923056,
                  0.23127198, 0.310534, 0.3501947],
                 [0.5763511, -0.4778806, 0.3984726, 0.13659805, -0.05111057,
                  0.4764889, 0.05881822, -0.37829816],
                 [-0.33052838, -0.3291011, -0.59498054, 0.2654276, -0.5715602,
                  0.01546502, 0.04336095, 0.13782066],
                 [-0.32840976, -0.37728345, -0.49385822, -0.49648887, 0.4832974,
                  0.07143259, -0.17042065, 0.43592864],
                 [0.31292784, 0.01520997, 0.40785295, -0.12775904, 0.03555053,
                  -0.35662168, -0.5096859, 0.33710766],
                 [-0.36864457, 0.30672514, -0.4093505, -0.4461822, -0.41201153,
                  0.12536913, -0.3134546, -0.110695],
                 [0.50774044, 0.25777447, -0.18048626, -0.30132556, 0.3435768,
                  0.49845392, -0.21432358, -0.05989999]]
            ))
        elif ("TransformerEncoder/layer_0/self_attention_prepost_wrapper"
              "/self_attention/qkv_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.24556783, -0.10109329, -0.18614727, -0.35749245, 0.07600775,
                  -0.30707863, 0.11381295, -0.21648653, -0.32361317, 0.04083973,
                  0.00325903, 0.17453268, -0.38458756, -0.12808836, -0.30286443,
                  -0.28138128, 0.3906658, 0.2981322, 0.1857591, -0.10963717,
                  0.13652292, -0.42696893, -0.32537884, -0.17609134],
                 [0.00684109, 0.40689567, 0.22115704, -0.22863819, -0.22739726,
                  0.3783851, -0.37274942, -0.21842214, -0.22557294, -0.07110339,
                  0.3998916, -0.0190008, 0.27676454, -0.19919433, 0.2616723,
                  -0.41782314, -0.2811813, -0.3239204, 0.13037983, 0.10246852,
                  -0.14516768, -0.13455674, -0.20624177, 0.30381766],
                 [-0.36161476, 0.3910825, 0.11459449, -0.19012608, -0.1930628,
                  -0.09042051, 0.04295725, -0.09732714, -0.27065122, -0.1735073,
                  -0.11896703, -0.2472982, -0.24865237, 0.0597097, -0.23580097,
                  -0.402398, -0.04311767, -0.14832097, 0.25989994, -0.03256506,
                  -0.3376931, 0.35324004, 0.01395121, -0.28511477],
                 [0.33902344, -0.16730174, 0.2059339, -0.0727739, -0.24657604,
                  0.01062217, -0.21674432, 0.11485538, 0.23314235, -0.30125052,
                  0.32238856, -0.2450316, 0.03718695, -0.276408, 0.23392966,
                  -0.07773718, 0.3429754, -0.19731745, 0.37889633, 0.34160677,
                  0.05413216, 0.03037485, -0.3704696, 0.28774682],
                 [-0.41983247, 0.1209394, -0.03301042, 0.20576969, -0.28212637,
                  -0.25600716, -0.09135348, -0.19963133, -0.1577549, -0.13313296,
                  -0.02467829, 0.39583513, -0.21820472, 0.10990372, -0.42987105,
                  -0.3018305, -0.33682942, -0.04609847, -0.0978007, -0.35909522,
                  0.35906085, -0.38199574, -0.02560577, 0.4065493],
                 [-0.39747363, -0.21786559, 0.4050602, 0.29975984, -0.03308517,
                  -0.05114299, 0.23231843, -0.42908302, -0.09869319, -0.3929163,
                  0.14195767, -0.04656759, 0.2699246, 0.1801227, 0.14472279,
                  -0.4127182, -0.4004244, -0.10136119, 0.4069151, 0.3895177,
                  -0.15835935, -0.13569432, -0.38402212, -0.16429195],
                 [-0.1027582, 0.02577147, 0.39300737, -0.10241205, -0.4256417,
                  0.33153847, -0.0325374, -0.13393977, 0.05391803, -0.20058648,
                  -0.25471783, 0.08702543, -0.09722248, 0.02570912, -0.279415,
                  0.04044545, -0.27716812, 0.19806209, 0.22688219, -0.30685633,
                  0.00624642, 0.14048973, -0.2722684, 0.39918897],
                 [-0.19335268, 0.38261148, 0.30058286, 0.25313148, 0.27221575,
                  0.37937936, 0.1745182, 0.14772478, -0.27204615, 0.38106957,
                  0.36370513, 0.16695651, -0.40864846, -0.14278689, 0.34316894,
                  0.41350552, -0.42566204, -0.22474506, -0.18263665, 0.11183658,
                  -0.12859318, 0.02102521, -0.1425604, 0.11403349]]
            ))
        elif "TransformerEncoder/layer_0/ffn_prepost_wrapper/ffn/dense1/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.38400275, 0.11049551, 0.19255298, 0.45194864, -0.02915239,
                  0.31835914, -0.3630433, 0.11081731, -0.02559841, 0.38685995],
                 [0.42969477, 0.2031151, 0.5144137, -0.07936049, 0.31766498,
                  0.5058452, 0.44898677, 0.16335446, 0.3953011, 0.4361714],
                 [0.04883695, -0.56701475, 0.09635973, -0.50472724, -0.1245037,
                  -0.37787604, -0.21818402, 0.16247958, -0.14578387, -0.41005552],
                 [0.13449967, 0.05132979, -0.5468524, -0.17919052, 0.01128888,
                  0.09902984, 0.23214585, -0.08920336, 0.55008626, 0.50717974],
                 [-0.1738911, -0.24616602, 0.18358463, -0.11349753, 0.15567136,
                  -0.45293823, 0.29155105, 0.49324703, 0.01795202, 0.255095],
                 [-0.23427847, -0.47127584, 0.47553408, 0.17752594, -0.4635463,
                  -0.05620468, -0.5232727, 0.39365137, -0.38289946, 0.05879569],
                 [0.25051618, 0.26999742, -0.24446961, 0.03792298, 0.01752973,
                  -0.41537094, 0.44205165, -0.11403576, -0.3807313, -0.23905703],
                 [-0.33319134, -0.47972375, 0.526567, 0.34260195, -0.01981884,
                  -0.02918285, -0.02829635, -0.5294999, 0.563005, 0.05829275]]
            ))
        elif "TransformerEncoder/layer_0/ffn_prepost_wrapper/ffn/dense2/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.2340402, -0.10299325, 0.03826767, -0.00556576, 0.16777557,
                  -0.48395926, -0.21232244, 0.540642],
                 [-0.5568968, -0.24176422, 0.17467064, 0.3885694, 0.4655552,
                  -0.15393665, -0.4475953, -0.3920542],
                 [0.07647067, 0.2340278, -0.13460535, -0.34944105, 0.0448994,
                  0.35044646, -0.5451377, -0.39633614],
                 [0.16932797, 0.4503368, -0.48202705, -0.05000919, -0.3586144,
                  0.07879007, -0.47378975, -0.5153118],
                 [-0.4939471, -0.49206224, 0.33845508, -0.5155843, -0.07823312,
                  0.30778152, -0.14456016, -0.49705222],
                 [0.23529834, 0.39454746, -0.3392254, -0.31639364, 0.39075094,
                  0.55396605, 0.03435838, 0.3698709],
                 [-0.01985615, -0.14796564, -0.04773241, 0.1197027, 0.02213496,
                  0.24299401, 0.23960501, 0.45019186],
                 [-0.1280163, -0.11015153, 0.19618726, -0.55472195, -0.45635638,
                  -0.15839794, 0.28029287, 0.00874251],
                 [-0.18816125, -0.16009945, -0.14088362, 0.41544813, -0.20673174,
                  0.01065433, 0.03431308, -0.17323837],
                 [-0.30255532, 0.5155908, 0.23801541, 0.46748185, -0.42719585,
                  -0.49111396, 0.3950773, -0.27734205]]
            ))
        elif ("TransformerEncoder/layer_1/self_attention_prepost_wrapper/"
              "self_attention/output_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.42618555, -0.09034979, -0.23231441, -0.43777925, 0.45706886,
                  -0.59829664, 0.4076385, 0.23851973],
                 [0.05634236, 0.17002487, -0.08434552, 0.31617373, 0.03625625,
                  0.5910465, -0.6076178, -0.2687951],
                 [-0.14819229, -0.27034125, 0.2064324, -0.19751346, 0.21064728,
                  0.29283345, 0.23406833, 0.10519284],
                 [0.31500018, -0.4173568, -0.00893188, -0.26349744, 0.15418595,
                  -0.399687, -0.22666007, -0.6096985],
                 [-0.1316917, -0.36008307, -0.43647486, 0.10060841, -0.16681895,
                  -0.35083786, 0.26369733, -0.12640283],
                 [0.5797457, -0.59191436, -0.57749504, -0.54847366, -0.20692074,
                  0.4509862, -0.01773721, 0.1577],
                 [0.4081785, 0.5246411, -0.5135473, -0.23788959, -0.26497075,
                  -0.23121881, 0.35329401, 0.42074102],
                 [-0.46347424, 0.56120163, -0.2939334, 0.2747522, 0.56474787,
                  0.5690356, 0.19718772, -0.09090984]]
            ))
        elif ("TransformerEncoder/layer_1/self_attention_prepost_wrapper/"
              "self_attention/qkv_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.04137263, 0.4122521, 0.07474831, -0.42290825, 0.01918331,
                  -0.0367808, 0.20840707, -0.19495474, -0.36590886, 0.12961635,
                  -0.42065755, 0.21793994, 0.15142605, 0.05064282, 0.3728448,
                  0.4305556, -0.19640265, -0.13260049, 0.41600618, -0.30270132,
                  0.28347465, -0.2972833, -0.22339822, -0.4168277],
                 [-0.42739302, 0.0618836, 0.30369553, -0.01105291, -0.2725063,
                  0.18827173, -0.07787129, 0.29560563, 0.11015823, -0.2556733,
                  0.3800684, 0.20649257, -0.03591421, 0.35618058, -0.39821273,
                  0.0430806, -0.37791556, -0.05824929, 0.29839876, 0.06364432,
                  -0.28479278, 0.37887844, -0.19407392, -0.24432379],
                 [-0.2754909, 0.21458694, 0.2540948, -0.06881586, 0.2752199,
                  -0.42529625, -0.18034342, -0.2641306, 0.08662507, -0.19239433,
                  -0.01936874, -0.42879313, 0.2515919, 0.05828688, -0.35050425,
                  0.19613442, 0.10595468, -0.06380415, 0.14495179, -0.26701403,
                  0.33381835, 0.11836699, 0.10901466, -0.19060831],
                 [-0.08439368, -0.1435681, -0.38354927, 0.29710206, 0.39372167,
                  0.29005793, 0.22486511, 0.10090873, -0.27392572, 0.12495866,
                  -0.38597837, 0.37385282, -0.15801638, 0.34403047, 0.05333185,
                  -0.19141418, -0.43146238, -0.09826642, 0.39207748, 0.02903318,
                  -0.0447951, -0.140995, 0.12605539, -0.27343658],
                 [-0.14746845, 0.26028237, -0.14068425, -0.02098277, -0.34208745,
                  -0.36879313, 0.3709258, -0.18287906, -0.38343272, 0.01450509,
                  0.33475187, 0.19835839, -0.02770916, -0.19535396, 0.24291894,
                  0.40508488, 0.1228393, 0.35743287, -0.31064862, -0.2738737,
                  -0.08634344, 0.17820784, 0.2404854, -0.21379128],
                 [0.32416382, 0.23761937, -0.2714734, 0.01659575, 0.12218228,
                  0.08210799, 0.39640966, 0.04924238, -0.10259542, -0.42907375,
                  -0.0455032, -0.04837993, -0.25596887, -0.16206014, -0.40621698,
                  0.10435715, 0.2919118, -0.3757009, 0.12669042, -0.06276929,
                  0.08691922, 0.01388359, 0.2609237, 0.14391366],
                 [-0.37109214, 0.08338836, 0.41613457, 0.09220138, 0.14755598,
                  -0.3846822, -0.32047546, -0.11989969, 0.04941088, 0.3733643,
                  -0.22359593, 0.01040426, -0.13329476, 0.03873777, 0.25831434,
                  0.04679212, -0.34217292, -0.23983024, 0.36969563, 0.35033616,
                  0.05077001, 0.32096437, 0.2942368, -0.06438693],
                 [0.04559416, 0.3110021, 0.10469446, -0.09112707, -0.21549596,
                  -0.08703595, 0.19566664, -0.27119064, -0.31012705, -0.3460493,
                  0.20034257, 0.34390983, -0.30513322, 0.30294558, 0.15193626,
                  -0.13466576, -0.15653265, -0.04085603, -0.04187199, -0.3818181,
                  0.35413423, -0.11948714, 0.12659273, 0.33491793]]
            ))
        elif "TransformerEncoder/layer_1/ffn_prepost_wrapper/ffn/dense1/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.16969907, 0.538725, -0.47220635, -0.39862955, 0.5590445,
                  -0.57381415, 0.55189013, -0.1241096, -0.1750552, 0.07282209],
                 [-0.04967839, -0.29894733, 0.48699057, -0.26354527, -0.11624891,
                  0.00518572, 0.06982511, 0.21453673, 0.52487314, 0.50849414],
                 [-0.29642364, -0.1552884, 0.37976956, -0.09915912, 0.21726537,
                  0.09865189, -0.3579256, 0.2882828, -0.5435448, 0.34120053],
                 [-0.16734263, -0.30591854, -0.48299694, 0.36032963, 0.3083346,
                  0.32025862, -0.0323239, -0.03540909, 0.19812691, 0.56041396],
                 [0.08146846, -0.4032659, 0.43548548, -0.505157, 0.29625255,
                  0.20229155, -0.2784496, -0.16810659, 0.00465661, -0.46176454],
                 [0.25855982, -0.44527876, -0.05630809, 0.44814825, 0.4672327,
                  0.07238638, 0.23067313, -0.31218028, 0.5251508, -0.46993703],
                 [0.36020505, 0.48421, 0.04297256, 0.07937276, 0.39654619,
                  0.08334208, -0.44477332, 0.15238297, -0.14505252, 0.5653666],
                 [0.17023551, 0.05648631, -0.5590816, -0.4013535, 0.00587964,
                  -0.41224653, -0.5178517, -0.44671488, -0.13213646, -0.16264695]]
            ))
        elif "TransformerEncoder/layer_1/ffn_prepost_wrapper/ffn/dense2/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.08363676, 0.443043, -0.20048293, 0.5397774, -0.08774236,
                  0.51563346, 0.44048393, 0.05069989],
                 [-0.39923793, 0.27010256, 0.3120396, 0.15755522, 0.09888685,
                  0.09209388, 0.23463911, -0.20073885],
                 [0.39725387, 0.3083284, 0.04398292, -0.5214203, 0.1661511,
                  0.32843602, 0.535144, -0.30733716],
                 [-0.52302945, 0.09949869, -0.20001906, -0.4563232, 0.10634673,
                  -0.0867821, 0.2130729, 0.15544009],
                 [-0.16209882, 0.47079623, -0.36366975, -0.39391387, -0.13728681,
                  0.36896384, -0.1279692, -0.24792987],
                 [0.4540763, 0.43117046, 0.34526706, -0.44267043, -0.2801833,
                  0.09091371, 0.31143135, -0.46842438],
                 [-0.3841617, 0.3537798, -0.456631, -0.07963607, 0.18825197,
                  0.34253138, 0.00311643, -0.39619297],
                 [0.19681883, 0.02538323, 0.49230504, -0.54670614, -0.16814995,
                  0.26320857, -0.2583875, -0.45845556],
                 [0.10035574, -0.33199033, -0.06377029, -0.38322705, 0.18576187,
                  0.30481344, 0.30165493, -0.56413436],
                 [0.13095653, 0.5693759, -0.34928244, -0.00579017, 0.45523894,
                  0.45559692, -0.4755445, -0.5578483]]
            ))
        elif ("TransformerDecoder/layer_0/self_attention_prepost_wrapper/"
              "self_attention/output_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.41402858, -0.2655511, 0.21687216, -0.05976683, -0.24678236,
                  -0.55986947, -0.10050869, 0.36443913],
                 [-0.31218863, -0.08026814, -0.3503775, -0.2830528, 0.19764078,
                  0.07665694, -0.22002375, 0.58338326],
                 [0.36593944, 0.47826117, -0.3155697, 0.22407556, -0.2367759,
                  0.5582003, -0.01308447, 0.02416301],
                 [-0.5932773, 0.54228276, 0.07887, -0.36850107, -0.57571995,
                  0.52597564, -0.12966257, -0.06494093],
                 [-0.5416004, -0.4324838, 0.5738513, 0.23318034, -0.5079873,
                  0.44698435, 0.1884408, -0.4100449],
                 [-0.41715717, -0.47995192, 0.27436692, 0.45396346, -0.32279193,
                  -0.52322745, -0.22139937, 0.46218258],
                 [0.04606843, -0.48210734, -0.09731799, 0.1566211, 0.3348605,
                  0.53798, 0.2066397, 0.17096424],
                 [0.5118193, -0.26824263, 0.0513528, -0.22810039, -0.02520913,
                  -0.25055912, -0.21125275, 0.01200509]]
            ))
        elif ("TransformerDecoder/layer_0/self_attention_prepost_wrapper/"
              "self_attention/qkv_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.24635717, -0.35896713, 0.39586702, -0.03602478, 0.27512792,
                  0.23269245, 0.29596278, -0.13523233, 0.3122929, 0.01758271,
                  0.19535479, 0.42010358, 0.3058509, -0.27858323, -0.09621406,
                  -0.28900337, -0.13637415, 0.2554522, -0.13693246, 0.23890129,
                  0.22502461, -0.00342193, -0.37178487, 0.04001474],
                 [-0.06197342, 0.28338936, 0.10876206, 0.21770415, -0.2445885,
                  -0.37382, 0.24960616, -0.28366768, 0.33277413, 0.24190459,
                  0.28501043, 0.2390792, -0.21722354, -0.09839588, -0.07514569,
                  0.08434585, -0.17455393, -0.39285085, 0.3604456, -0.04403484,
                  0.17325982, 0.266789, 0.27641353, 0.2629675],
                 [0.31777444, -0.18994613, 0.07876977, 0.19285682, -0.3603885,
                  -0.07359949, 0.39663008, 0.12972179, 0.32373634, -0.28222823,
                  0.07523808, 0.06840143, 0.2784874, -0.32616594, -0.37903282,
                  0.11678198, -0.2441357, -0.15710688, -0.00175741, -0.40035915,
                  -0.09226942, 0.08680966, 0.25157234, 0.00786397],
                 [-0.06718335, -0.21293627, 0.23377934, -0.07398105, -0.04577821,
                  0.4012753, -0.36116257, 0.27832034, 0.20620236, -0.15069339,
                  0.16214707, -0.42465132, 0.25478825, -0.08184978, 0.35768852,
                  -0.12693104, -0.1273953, -0.3078432, 0.33522883, 0.34014687,
                  -0.08295268, -0.36013618, -0.08690733, -0.07324457],
                 [-0.0609462, 0.06251469, -0.04659629, 0.3167083, -0.02005619,
                  0.32234064, 0.35482922, -0.0772118, 0.3867505, 0.3833268,
                  -0.2319926, -0.417385, -0.38126078, 0.37261078, 0.0596388,
                  0.09162065, -0.23212992, -0.25532508, -0.3144799, 0.28181675,
                  0.01341996, 0.19811288, -0.21834192, -0.39427295],
                 [-0.13712531, 0.2572454, 0.2866812, 0.10211042, 0.06285053,
                  -0.3894317, -0.04404226, -0.39091605, -0.16874191, 0.08648756,
                  -0.30481267, 0.16437915, -0.23644, 0.07409009, -0.39548072,
                  0.35895494, 0.03730175, 0.4324384, -0.2938407, 0.38754657,
                  -0.3012539, -0.11363283, -0.28678095, -0.1598432],
                 [0.00581551, 0.14337441, -0.04939786, 0.11189356, 0.31094417,
                  0.01152644, 0.27642164, -0.09637818, -0.09211436, -0.16248363,
                  0.39744857, -0.4116622, -0.05383742, 0.36805126, 0.14875862,
                  0.1099014, 0.371321, -0.41085994, -0.18536153, 0.20604655,
                  -0.13384223, -0.14118773, -0.1283133, -0.39778396],
                 [-0.01566258, -0.4047187, -0.37664068, -0.19478449, 0.09347895,
                  -0.36023095, 0.21561489, -0.33089578, -0.2711009, -0.03610542,
                  -0.3796572, 0.306676, 0.27266768, 0.22641936, -0.30573982,
                  -0.18740533, -0.34311372, -0.22143514, -0.41552392, 0.42686227,
                  -0.1086936, 0.03383243, -0.15354112, -0.26625448]]
            ))
        elif ("TransformerDecoder/layer_0/encdec_attention_prepost_wrapper/"
              "encdec_attention/output_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.1279256, -0.2419937, -0.5854874, 0.57889825, -0.5364065,
                  0.23631936, -0.49949092, -0.30174196],
                 [0.00957078, -0.49736997, -0.4237002, 0.0218854, -0.17279565,
                  -0.5768471, -0.18963015, 0.10355526],
                 [0.11799914, -0.292151, -0.36201292, -0.266887, 0.15741825,
                  -0.11333472, -0.03553617, 0.0177772],
                 [-0.39861536, 0.17891657, -0.22581154, 0.07609612, -0.34631196,
                  0.26317436, 0.41848058, 0.27004486],
                 [-0.37255478, -0.20311174, 0.5176136, -0.54658747, 0.23746693,
                  -0.03754926, 0.04889613, -0.41350323],
                 [0.2125783, -0.536155, -0.19549471, 0.36943835, 0.24639928,
                  0.07458866, 0.28700095, -0.36578485],
                 [-0.2657523, -0.2433975, -0.56110847, -0.2861476, -0.19445652,
                  0.21033949, -0.30730212, 0.40339154],
                 [0.31910568, 0.0055629, 0.03742898, -0.5246967, 0.35341913,
                  0.3554458, 0.5315719, 0.13093019]]
            ))
        elif ("TransformerDecoder/layer_0/encdec_attention_prepost_wrapper/"
              "encdec_attention/q_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.51545686, 0.0990485, 0.29777205, -0.28110617, -0.26308733,
                  0.2853282, -0.31212774, 0.30727994],
                 [0.5417524, 0.12922692, 0.3285774, -0.02031326, 0.08855647,
                  -0.00454164, 0.02288318, 0.39679402],
                 [-0.09431475, -0.2857204, -0.29803967, 0.28193474, 0.26423824,
                  -0.31383288, -0.25300246, -0.01376557],
                 [0.12011659, 0.55608934, -0.01549584, -0.48516896, -0.44164532,
                  -0.16531923, 0.44081384, -0.54160094],
                 [-0.3235532, 0.55393785, 0.2136209, 0.08658487, 0.02760661,
                  -0.24593821, 0.23313332, -0.03452164],
                 [-0.3659288, -0.55161166, -0.5393511, -0.08154327, 0.47045785,
                  -0.2545886, 0.603108, 0.17091894],
                 [-0.41575676, -0.24764174, 0.33940715, -0.49895483, 0.14083397,
                  0.05251276, 0.09940594, 0.30034548],
                 [-0.5737393, -0.45933425, -0.02393657, -0.12469256, -0.24861848,
                  0.48773366, -0.38281965, 0.06820959]]
            ))
        elif ("TransformerDecoder/layer_0/encdec_attention_prepost_wrapper/"
              "encdec_attention/kv_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.3608706, -0.16985774, 0.04648876, 0.17727554, -0.32050753,
                  0.15797412, 0.32923543, -0.19890809, -0.09514797, 0.09165347,
                  -0.08939207, 0.1240828, 0.12936771, -0.48354328, 0.09154546,
                  0.06640613],
                 [0.26706707, -0.07982218, -0.28840077, -0.15964293, 0.44048142,
                  0.10202003, -0.19224763, 0.4643935, -0.49145675, 0.28452814,
                  -0.28381097, -0.1886301, 0.3626212, 0.48149836, -0.40126383,
                  0.01182055],
                 [0.48325312, 0.13339198, 0.08147466, 0.01886415, 0.410465,
                  -0.24456823, -0.04810286, 0.3934772, -0.42655325, -0.12829137,
                  0.47660065, -0.3516115, -0.11145651, -0.02882326, -0.38462532,
                  0.16618061],
                 [0.28752756, -0.09809136, -0.06697667, -0.22326052, 0.33962095,
                  -0.06639445, -0.06673455, 0.03969002, 0.03658247, 0.2047621,
                  0.41957307, -0.27317607, -0.1286192, -0.1504153, -0.08790445,
                  -0.27503848],
                 [0.40700352, -0.13340664, 0.48895872, 0.2091173, -0.4158994,
                  0.42262292, 0.45204484, 0.31661832, -0.16831684, -0.43958127,
                  0.40800595, 0.4231466, 0.2662462, 0.4360491, -0.05090606,
                  0.41579437],
                 [-0.1475159, 0.05631268, 0.43667984, 0.22322762, 0.24188244,
                  -0.2558658, 0.05513358, -0.44220436, 0.47696745, 0.30288208,
                  0.35236907, -0.46022415, -0.2354449, -0.2824862, 0.1728853,
                  0.00242376],
                 [-0.19901407, -0.17316806, 0.34936786, 0.05637395, -0.08862174,
                  0.15412652, 0.14734995, -0.02360725, 0.20836592, 0.10715961,
                  0.21128082, -0.01028705, 0.27915657, 0.00645471, 0.34993672,
                  0.46311176],
                 [0.40358865, -0.12622762, 0.11518359, 0.18501854, 0.01984668,
                  0.45133805, 0.1628021, -0.17971015, -0.16342247, -0.22245312,
                  -0.26478374, 0.160591, 0.4486302, -0.19825566, 0.04753971,
                  0.12643707]]
            ))
        elif "TransformerDecoder/layer_0/ffn_prepost_wrapper/ffn/dense1/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.06720757, 0.55263114, -0.37820417, -0.18817183, 0.4967841,
                  0.5301496, 0.44765162, 0.17229474, 0.02037746, -0.38267606],
                 [0.22507912, 0.08319503, -0.42931908, 0.21395624, 0.4883101,
                  0.02807504, -0.10768619, -0.47498938, 0.04546309, 0.51695967],
                 [-0.32582825, -0.15555033, -0.35707173, -0.00528497, 0.11157733,
                  -0.4079039, -0.20309281, -0.2786939, -0.00143158, -0.45975608],
                 [0.0592798, -0.297385, 0.35483736, 0.2347272, -0.3477485,
                  0.26017946, -0.17936438, 0.44473732, -0.28609666, -0.14807671],
                 [-0.3869655, -0.5571348, -0.38598603, -0.41803488, 0.43944812,
                  -0.3425563, 0.25616652, -0.0285089, -0.0508908, -0.54111296],
                 [-0.44107342, -0.5042058, 0.5217055, -0.34677118, 0.475623,
                  0.18002027, -0.44467062, 0.05279869, -0.30962384, -0.45696396],
                 [-0.11149651, 0.3705026, -0.5126401, 0.06722903, 0.22575969,
                  -0.23028824, 0.2056027, -0.39192414, -0.25298402, 0.4379238],
                 [0.14971024, 0.42451167, 0.37757248, -0.3726549, -0.17506334,
                  -0.46460786, -0.02499455, 0.13482589, -0.12902525, -0.19523734]]
            ))
        elif "TransformerDecoder/layer_0/ffn_prepost_wrapper/ffn/dense2/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.44042823, 0.45197666, -0.23344472, 0.45998847, -0.17414865,
                  0.4641745, 0.4826498, -0.1315352],
                 [0.41060203, -0.211938, 0.08441406, 0.2431289, -0.38785285,
                  0.35918987, 0.07967973, -0.19248444],
                 [0.17039984, 0.01675391, -0.19650468, 0.10323095, -0.02209324,
                  -0.24919105, 0.16697949, 0.11663049],
                 [0.17856616, -0.20257097, 0.3182906, 0.1157276, -0.45809188,
                  -0.13065588, -0.5293646, -0.04682791],
                 [-0.19376227, -0.5453018, -0.0328182, -0.5452718, 0.26869357,
                  0.13249546, 0.08024281, 0.11003381],
                 [-0.23756227, -0.29575357, -0.50909173, -0.05765748, -0.0089184,
                  0.489527, 0.0540911, -0.20290643],
                 [-0.43088597, -0.03776497, -0.07004839, 0.3612193, 0.2700277,
                  0.3630551, -0.35514504, 0.0078786],
                 [-0.3577707, 0.5772364, -0.45408776, 0.04695731, 0.12955356,
                  0.08641922, -0.06749266, -0.22854668],
                 [0.3447554, -0.50018543, -0.4450423, -0.345627, 0.4853915,
                  -0.38487256, -0.23583022, 0.41968864],
                 [0.5223309, 0.34582454, 0.24228495, 0.4505279, 0.00524783,
                  0.33739161, 0.1729073, 0.46376586]]
            ))
        elif ("TransformerDecoder/layer_1/self_attention_prepost_wrapper/"
              "self_attention/output_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.446121, 0.3940925, 0.49132103, 0.17713946, 0.5267928,
                  0.33675808, 0.44058722, -0.43157172],
                 [-0.23504972, 0.1617412, 0.2769773, -0.26133326, 0.24745297,
                  -0.0520584, 0.07277727, -0.5577672],
                 [-0.29327726, 0.2514521, 0.32843417, 0.5675153, -0.5442774,
                  -0.24685362, -0.3434327, 0.29523093],
                 [0.25270784, -0.20233193, -0.13284832, 0.28228354, -0.4794641,
                  0.12789321, -0.39262465, 0.04397899],
                 [-0.60009784, 0.45697302, -0.32597286, -0.03012645, 0.01654047,
                  -0.3432645, -0.52298236, -0.45876426],
                 [-0.19784635, 0.01058447, -0.58458495, -0.5126084, -0.5655494,
                  -0.41740847, -0.19458848, -0.10731643],
                 [-0.5258043, -0.61217636, -0.47019628, -0.3324889, -0.39158016,
                  0.36343306, -0.36333203, -0.22256723],
                 [0.24401158, -0.13122407, 0.5713683, -0.6086697, 0.12495714,
                  0.25823617, -0.09232122, 0.5900312]]))
        elif ("TransformerDecoder/layer_1/self_attention_prepost_wrapper/"
              "self_attention/qkv_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.27481452, 0.0116879, 0.36719075, -0.40440372, -0.1954606,
                  -0.2300574, -0.04979965, 0.15613547, 0.32280543, 0.3273132,
                  0.3912786, 0.4046168, -0.30568987, -0.33408988, 0.15435639,
                  -0.08106208, 0.32937118, -0.34070706, 0.0546439, -0.24983734,
                  0.0207603, 0.08601627, -0.27549195, 0.20412138],
                 [0.14348724, 0.18185094, 0.167887, -0.3021682, 0.2971051,
                  0.07907161, -0.37291273, -0.26329404, 0.24814805, -0.00783703,
                  -0.1134795, 0.25298938, -0.0403159, 0.09382078, -0.25310278,
                  0.42588016, -0.0232923, -0.23894715, 0.26872233, -0.3017637,
                  0.35517278, 0.4123756, 0.35715845, -0.2612683],
                 [0.251209, 0.30718777, -0.09743929, 0.37868705, -0.3782806,
                  -0.10440734, -0.20695278, -0.42843944, 0.11033848, 0.4274877,
                  0.21334943, 0.3301848, 0.31885192, 0.3971382, -0.09676668,
                  0.22961542, 0.28164133, 0.28870395, 0.24603716, 0.13049194,
                  -0.26271415, 0.3598245, 0.17889282, -0.09679371],
                 [0.18480167, -0.423978, 0.28147706, 0.20233068, 0.07700345,
                  0.3950176, 0.16953233, -0.2767653, -0.0351927, -0.3871778,
                  -0.10333872, -0.38401458, 0.08614203, -0.09418231, 0.1258482,
                  0.41503003, -0.23736389, 0.3829991, 0.20315519, -0.0506267,
                  0.02750155, 0.18088666, 0.32316545, 0.07156941],
                 [-0.3365289, 0.07633492, 0.18811491, 0.12218675, -0.01712888,
                  0.11047456, 0.36789885, 0.07453135, 0.35507998, 0.32413712,
                  0.06988475, -0.316629, -0.09560555, -0.3577586, 0.11743674,
                  -0.1154238, 0.40550312, -0.28373045, -0.28391486, 0.22130796,
                  0.19461158, 0.34828517, 0.3402731, 0.42168418],
                 [0.22959384, -0.09466672, 0.13875905, 0.06585011, -0.08454975,
                  -0.25139913, 0.24867311, -0.19710684, -0.38250047, 0.05279905,
                  0.09058633, 0.05691019, -0.43189391, -0.00754103, -0.42296854,
                  -0.17274147, -0.1439153, -0.16499841, 0.4218262, 0.27872702,
                  0.269519, -0.284347, 0.00676736, -0.24074432],
                 [-0.43105984, -0.18570966, -0.25307292, -0.19746126, 0.11514279,
                  0.101432, -0.12518859, 0.10440406, -0.42490405, 0.05715063,
                  -0.2929991, 0.2661244, -0.12404522, 0.06171378, -0.15130952,
                  0.29441395, -0.41733328, 0.08141616, -0.34677923, -0.05524972,
                  0.18937346, -0.41702378, -0.06657425, 0.27120963],
                 [0.07061633, 0.23987249, 0.22944674, 0.08817294, 0.22188488,
                  -0.37523416, -0.3636308, 0.26619443, 0.05310896, -0.3865527,
                  -0.0594418, 0.10325739, 0.14090309, -0.02832022, 0.09751496,
                  -0.0530881, -0.04750797, -0.32113245, 0.25775167, -0.2249531,
                  0.17214248, -0.20723793, 0.05858463, -0.1042015]]))
        elif ("TransformerDecoder/layer_1/encdec_attention_prepost_wrapper/"
              "encdec_attention/output_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.07463336, 0.0563764, 0.26746285, 0.58845574, 0.37224877,
                  0.22249967, -0.24321106, -0.48173416],
                 [-0.30540833, 0.24408221, -0.06326765, -0.11097288, 0.10069352,
                  -0.04288429, -0.44742495, 0.166543],
                 [0.14135772, -0.26862615, -0.50849557, 0.5784133, -0.40443277,
                  0.51631385, -0.07799548, 0.28732932],
                 [-0.09749961, 0.40039545, -0.06118071, -0.15212688, 0.34009832,
                  0.5772465, 0.48222512, -0.25559646],
                 [-0.37269944, -0.15007514, 0.11866188, -0.0120635, -0.0109489,
                  -0.60186726, -0.28244707, 0.32835752],
                 [0.559184, 0.29157156, -0.35879636, 0.24650383, 0.5976046,
                  -0.15556344, -0.11127496, -0.3011105],
                 [0.5442193, -0.20431828, 0.36724424, -0.4528572, 0.10426587,
                  0.11822385, -0.05441982, 0.07673579],
                 [-0.37118763, -0.24179482, -0.47427145, -0.17455658, 0.46202105,
                  0.24439615, -0.40861088, 0.2468313]]))
        elif ("TransformerDecoder/layer_1/encdec_attention_prepost_wrapper/"
              "encdec_attention/q_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.527518, -0.12114212, -0.40808892, 0.56731755, 0.2572146,
                  0.31378293, 0.20443302, -0.5630253],
                 [0.6023007, -0.08801287, -0.55323726, -0.49235207, 0.18328917,
                  -0.30462766, 0.4235236, -0.14947698],
                 [0.05836785, -0.32457548, -0.5583779, 0.17587304, 0.13842088,
                  -0.06220692, 0.05683714, -0.08522952],
                 [0.11454928, 0.57845205, 0.40677744, -0.32356766, -0.10824966,
                  0.5729895, 0.09953862, -0.49825168],
                 [-0.1325807, -0.5300193, -0.09281999, 0.23173773, -0.6103119,
                  -0.17548105, -0.40918946, -0.6055349],
                 [-0.26868924, -0.3843334, -0.14497796, 0.27963597, 0.38890153,
                  -0.36425418, 0.13343394, -0.17070243],
                 [-0.333827, 0.16035432, 0.17401373, -0.27310547, -0.23915032,
                  -0.3207253, -0.00749028, 0.4876346],
                 [0.3249125, -0.29519892, 0.49359602, -0.601942, -0.2753108,
                  -0.39890692, 0.04002428, 0.41897768]]))
        elif ("TransformerDecoder/layer_1/encdec_attention_prepost_wrapper/"
              "encdec_attention/kv_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.00709212, -0.4091934, 0.26065922, 0.40150464, 0.26608098,
                  -0.3953911, -0.34422696, -0.06396389, -0.42655826, 0.35439622,
                  -0.20109999, -0.18769062, 0.0049336, -0.06693316, -0.4382484,
                  0.00183201],
                 [-0.02701962, 0.41023743, 0.02444375, 0.25569785, 0.04378641,
                  -0.37053585, 0.06267512, -0.06767642, -0.44424844, 0.2922008,
                  -0.44157362, -0.17749298, 0.17760682, -0.23238945, 0.3380952,
                  0.3164295],
                 [0.20117998, -0.13788939, 0.14445269, -0.31664026, 0.49193084,
                  0.08778274, -0.17864335, 0.16035259, -0.17492938, -0.04081237,
                  -0.4904747, -0.44932437, -0.19341111, -0.24871266, 0.38286912,
                  -0.06130087],
                 [0.2936057, -0.40730655, 0.18446267, 0.4097544, -0.0082581,
                  0.4734217, -0.46421993, -0.12871945, 0.22802174, 0.11106157,
                  0.26079726, -0.15126705, 0.40684378, -0.10213089, -0.24696314,
                  -0.02051508],
                 [-0.39994586, 0.16061008, 0.39812696, -0.3340621, -0.2076987,
                  0.20246327, -0.35409093, -0.4005847, -0.14170253, -0.21880937,
                  0.4408716, 0.22332358, -0.05699933, 0.17266095, 0.12294924,
                  0.38497412],
                 [-0.09543967, -0.34888685, -0.42740452, 0.1517607, -0.00862324,
                  -0.14572752, 0.47876465, -0.20919883, 0.32560217, 0.4249189,
                  -0.3933282, -0.22128391, -0.34623587, 0.14449048, -0.3857503,
                  -0.27833867],
                 [0.11869216, 0.05883706, -0.21212506, 0.49957561, 0.15783632,
                  -0.13721228, -0.21416295, -0.24007809, -0.294443, -0.16767824,
                  0.32042253, -0.31908023, 0.19871199, -0.43558514, -0.15620553,
                  0.11092794],
                 [-0.04378927, 0.35632384, 0.20292461, -0.27540374, 0.22871876,
                  -0.3632071, -0.40689313, 0.23316133, 0.37361324, -0.01663148,
                  -0.12638855, -0.32248807, -0.20867753, 0.2503358, -0.39324427,
                  -0.42774928]]))
        elif "TransformerDecoder/layer_1/ffn_prepost_wrapper/ffn/dense1/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.24261475, -0.18643704, -0.01811624, 0.50356495, 0.01885831,
                  -0.2399435, 0.23692662, -0.10759905, -0.38264602, 0.1351049],
                 [0.21200335, -0.38962328, 0.29363745, 0.33583325, -0.24011764,
                  0.3635068, 0.4376179, 0.22551686, 0.5667083, -0.32501143],
                 [-0.49261767, 0.1927172, -0.0046156, -0.56056315, 0.47630668,
                  -0.31453356, 0.42453694, -0.32902807, 0.14415932, -0.5471806],
                 [-0.3316853, 0.13726503, -0.40464914, 0.28158778, 0.47430885,
                  -0.2569832, -0.5204258, -0.06528652, -0.5178821, 0.14735901],
                 [0.5328666, -0.12720194, 0.5184237, 0.411116, -0.3576244,
                  0.34368336, 0.16382056, -0.33515644, 0.17608005, 0.26269817],
                 [0.15965605, -0.25152162, -0.14534956, -0.2822171, 0.21284288,
                  0.05559379, 0.00327557, -0.4569926, -0.41969606, -0.56579554],
                 [-0.43731868, 0.32843924, 0.29003292, 0.1792146, -0.33100158,
                  -0.14961275, 0.12364352, -0.24879637, -0.39719564, 0.18711275],
                 [0.05891687, 0.47468245, -0.20260152, -0.3408, 0.5017748,
                  0.1640119, 0.22170597, -0.34292257, -0.31018573, -0.07051545]]))
        elif "TransformerDecoder/layer_1/ffn_prepost_wrapper/ffn/dense2/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.01111823, -0.50019276, -0.33186796, 0.52229214, -0.4700832,
                  0.5457233, -0.21241191, 0.37699038],
                 [-0.28677762, -0.51243806, 0.52265644, -0.29745945, -0.35470137,
                  -0.5047183, 0.18846446, -0.17220777],
                 [-0.46509957, -0.00087285, -0.22127637, 0.4205513, -0.46209753,
                  -0.11040562, -0.0872128, 0.34856063],
                 [0.33827233, -0.31306413, -0.49311733, -0.49154714, -0.43418467,
                  0.11416692, 0.46271265, -0.1998105],
                 [0.05865157, -0.19406608, 0.2172538, -0.2894684, 0.2942767,
                  0.19267291, -0.31736228, -0.04036039],
                 [-0.49561584, -0.22174796, 0.15456653, -0.3632484, -0.4434304,
                  -0.30227244, -0.4071117, 0.4257239],
                 [0.2923094, 0.52523994, 0.22059155, 0.22125322, -0.30496007,
                  -0.20421728, -0.5533153, 0.28908247],
                 [-0.01375407, -0.42056724, -0.42731434, 0.14045459, -0.10852379,
                  -0.14693105, 0.3797375, 0.5360898],
                 [0.01416886, 0.2641362, -0.55372095, -0.17806509, -0.43746334,
                  -0.39878494, -0.5338729, -0.50196886],
                 [0.5125271, -0.31531927, -0.4611238, 0.38278532, -0.05637842,
                  0.23722917, -0.11141762, 0.44730043]]))
    outputs = model(parsed_inputs, is_training=False)
    assert numpy.sum((outputs.numpy() - numpy.array(
        [[[0.5600359, 1.0880388, 0.18974903, 1.8916442,
           0.8008492],
          [1.0519575, 1.1763976, 0.42835617, 0.5486565,
           0.7540616],
          [-0.09629793, 1.9182953, 0.4154176, -0.09568319,
           0.32058734]],

         [[0.68914187, 1.1119794, -0.5154613, 1.8321573,
           0.93645334],
          [-0.93543077, 1.9193068, 1.5986707, -1.1064756,
           -0.1642181],
          [0.2821706, 1.199893, -1.3765914, 0.02889553,
           1.045481]]])) ** 2) < 1e-9

    # test share / no share
    params = copy.deepcopy(params)
    params["modality.share_embedding_and_softmax_weights"] = True
    params["modality.share_source_target_embedding"] = True
    model = build_model({"class": "transformer", "params": params},
                        src_meta=src_vocab_meta, trg_meta=src_vocab_meta)
    _ = model(parsed_inputs, is_training=False)
    assert len(model._src_modality.trainable_weights) == 2
    for w in model._src_modality.trainable_weights:
        if "weights" in w.name:
            assert "shared_symbol_modality" in w.name
    assert len(model._trg_modality.trainable_weights) == 2
    for w in model._trg_modality.trainable_weights:
        if "weights" in w.name:
            assert "shared_symbol_modality" in w.name
    assert model._output_linear_layer is None

    params = copy.deepcopy(params)
    params["modality.share_embedding_and_softmax_weights"] = False
    params["modality.share_source_target_embedding"] = True
    model = build_model({"class": "transformer", "params": params},
                        src_meta=src_vocab_meta, trg_meta=src_vocab_meta)
    _ = model(parsed_inputs, is_training=False)
    assert len(model._trg_modality.trainable_weights) == 1
    assert "shared_symbol_modality" in model._trg_modality.trainable_weights[0].name
    assert len(model._src_modality.trainable_weights) == 1
    assert "shared_symbol_modality" in model._src_modality.trainable_weights[0].name
    assert model._output_linear_layer is not None

    params = copy.deepcopy(params)
    params["modality.share_embedding_and_softmax_weights"] = True
    params["modality.share_source_target_embedding"] = False
    model = build_model({"class": "transformer", "params": params},
                        src_meta=src_vocab_meta, trg_meta=src_vocab_meta)
    _ = model(parsed_inputs, is_training=False)
    assert len(model._trg_modality.trainable_weights) == 2
    for w in model._trg_modality.trainable_weights:
        if "weights" in w.name:
            assert "target_symbol_modality" in w.name
    assert len(model._src_modality.trainable_weights) == 1
    assert "input_symbol_modality" in model._src_modality.trainable_weights[0].name
    assert model._output_linear_layer is None

    params = copy.deepcopy(params)
    params["modality.share_embedding_and_softmax_weights"] = False
    params["modality.share_source_target_embedding"] = False
    model = build_model({"class": "transformer", "params": params},
                        src_meta=src_vocab_meta, trg_meta=src_vocab_meta)
    _ = model(parsed_inputs, is_training=False)
    assert len(model._trg_modality.trainable_weights) == 1
    assert "target_symbol_modality" in model._trg_modality.trainable_weights[0].name
    assert len(model._src_modality.trainable_weights) == 1
    assert "input_symbol_modality" in model._src_modality.trainable_weights[0].name
    assert model._output_linear_layer is not None


def _run_as_script():
    """Execute the seq2seq test directly, without going through pytest."""
    test_seq2seq()


if __name__ == "__main__":
    # Only fires when the file is executed as a script; importing the module
    # (e.g. during pytest collection) does not run anything here.
    _run_as_script()
